import btk
import unittest
import _TDDConfigure
import numpy

class AnalogOffsetRemoverTest(unittest.TestCase):
  def test_Test0(self):
    """Using one acquisition as both raw and offset input must cancel the signal."""
    acquisition = btk.btkAcquisition()
    acquisition.Init(0, 25, 1, 1)
    constant = numpy.empty([25, 1])
    constant.fill(5.0)
    acquisition.GetAnalog(0).SetValues(constant)
    # The same object feeds both inputs, so the single channel is offset by
    # itself and its samples are expected to sum to zero.
    remover = btk.btkAnalogOffsetRemover()
    remover.SetRawInput(acquisition)
    remover.SetOffsetInput(acquisition)
    remover.Update()
    result = remover.GetOutput()
    self.assertEqual(result.GetAnalogNumber(), 1)
    self.assertEqual(result.GetAnalog(0).GetValues().sum(), 0.0)
  
  def test_Test1234(self):
    """Four raw channels, each offset by its counterpart in a four-channel offset acquisition."""
    raw = btk.btkAcquisition()
    raw.Init(0, 25, 4, 1)
    samples = numpy.empty([25, 1])
    for channel, level in enumerate((5.0, 4.0, 3.0, 2.0)):
      samples.fill(level)
      raw.GetAnalog(channel).SetValues(samples)
    offset = btk.btkAcquisition()
    offset.Init(0, 25, 4, 1)
    for channel, level in enumerate((1.0, 2.0, 3.0, 4.0)):
      samples.fill(level)
      offset.GetAnalog(channel).SetValues(samples)
    remover = btk.btkAnalogOffsetRemover()
    remover.SetRawInput(raw)
    remover.SetOffsetInput(offset)
    remover.Update()
    result = remover.GetOutput()
    self.assertEqual(result.GetAnalogNumber(), 4)
    # Expected per-channel mean (sum over the 25 samples / 25) is the raw
    # level minus the matching offset level.
    for channel, expected in enumerate((4.0, 2.0, 0.0, -2.0)):
      self.assertEqual(result.GetAnalog(channel).GetValues().sum() / 25.0, expected)
  
  def test_Test3Over4(self):
    """Offset acquisition with 3 channels applied to a raw acquisition with 4.

    Mirror image of test_Test4Over3. Raw channel 3 has no counterpart in the
    offset acquisition, so it is expected to pass through unchanged
    (mean 2.0). The first three channels lose their matching offset level.
    """
    raw = btk.btkAcquisition()
    raw.Init(0, 25, 4, 1)
    samples = numpy.empty([25, 1])
    for channel, level in enumerate((5.0, 4.0, 3.0, 2.0)):
      samples.fill(level)
      raw.GetAnalog(channel).SetValues(samples)
    # Only three offset channels. Initialising four here would leave a zero
    # channel paired with raw channel 3, and the "fewer offset channels"
    # case would never be exercised.
    offset = btk.btkAcquisition()
    offset.Init(0, 25, 3, 1)
    for channel, level in enumerate((1.0, 2.0, 3.0)):
      samples.fill(level)
      offset.GetAnalog(channel).SetValues(samples)
    remover = btk.btkAnalogOffsetRemover()
    remover.SetRawInput(raw)
    remover.SetOffsetInput(offset)
    remover.Update()
    result = remover.GetOutput()
    # The output keeps the raw acquisition's channel count.
    self.assertEqual(result.GetAnalogNumber(), 4)
    # Per-channel mean: raw level minus offset level where an offset exists,
    # otherwise the unchanged raw level.
    for channel, expected in enumerate((4.0, 2.0, 0.0, 2.0)):
      self.assertEqual(result.GetAnalog(channel).GetValues().sum() / 25.0, expected)
  
  def test_Test4Over3(self):
    """Offset acquisition with 4 channels applied to a raw acquisition with 3."""
    raw = btk.btkAcquisition()
    raw.Init(0, 25, 3, 1)
    samples = numpy.empty([25, 1])
    for channel, level in enumerate((5.0, 4.0, 3.0)):
      samples.fill(level)
      raw.GetAnalog(channel).SetValues(samples)
    offset = btk.btkAcquisition()
    offset.Init(0, 25, 4, 1)
    for channel, level in enumerate((1.0, 2.0, 3.0, 4.0)):
      samples.fill(level)
      offset.GetAnalog(channel).SetValues(samples)
    remover = btk.btkAnalogOffsetRemover()
    remover.SetRawInput(raw)
    remover.SetOffsetInput(offset)
    remover.Update()
    result = remover.GetOutput()
    # The surplus offset channel is ignored: the output keeps the raw count.
    self.assertEqual(result.GetAnalogNumber(), 3)
    for channel, expected in enumerate((4.0, 2.0, 0.0)):
      self.assertEqual(result.GetAnalog(channel).GetValues().sum() / 25.0, expected)
  
  def test_TestNoCommonLabel(self):
    """When no labels match, the raw analog channels are left as they were."""
    raw = btk.btkAcquisition()
    raw.Init(0, 25, 3, 1)
    # Raw values are never set; the zero-sum expectations below rely on
    # Init leaving them at zero.
    for channel, label in enumerate(("FOO", "BAR", "FOOBAR")):
      raw.GetAnalog(channel).SetLabel(label)
    offset = btk.btkAcquisition()
    offset.Init(0, 25, 3, 1)
    samples = numpy.empty([25, 1])
    for channel, level in enumerate((1.0, 2.0, 3.0)):
      samples.fill(level)
      offset.GetAnalog(channel).SetValues(samples)
    remover = btk.btkAnalogOffsetRemover()
    remover.SetRawInput(raw)
    remover.SetOffsetInput(offset)
    remover.Update()
    result = remover.GetOutput()
    self.assertEqual(result.GetAnalogNumber(), 3)
    for channel in range(3):
      self.assertEqual(result.GetAnalog(channel).GetValues().sum(), 0.0)
  
  def test_TestOneCommonLabel(self):
    """Only the raw channel whose label appears in the offset acquisition is corrected."""
    raw = btk.btkAcquisition()
    raw.Init(0, 25, 3, 1)
    for channel, label in enumerate(("FOO", "BAR", "FOOBAR")):
      raw.GetAnalog(channel).SetLabel(label)
    offset = btk.btkAcquisition()
    offset.Init(0, 25, 3, 1)
    samples = numpy.empty([25, 1])
    for channel, level in enumerate((1.0, 2.0, 3.0)):
      samples.fill(level)
      offset.GetAnalog(channel).SetValues(samples)
    # Relabel the level-3.0 offset channel so that it pairs with raw "FOO".
    offset.GetAnalog(2).SetLabel("FOO")
    remover = btk.btkAnalogOffsetRemover()
    remover.SetRawInput(raw)
    remover.SetOffsetInput(offset)
    remover.Update()
    result = remover.GetOutput()
    self.assertEqual(result.GetAnalogNumber(), 3)
    # "FOO" (raw level 0) minus offset 3.0 gives -3.0. The others are unmatched.
    for channel, expected in enumerate((-3.0, 0.0, 0.0)):
      self.assertEqual(result.GetAnalog(channel).GetValues().sum() / 25.0, expected)
  